{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wfdb import io, plot\n",
    "import wfdb\n",
    "import os\n",
    "import gc\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "%matplotlib notebook\n",
    "import pandas as pd\n",
    "import math\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Dropout, Input\n",
    "from keras.layers import CuDNNLSTM, LSTM\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from sklearn.utils import shuffle\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import time\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def comments_to_dict(comments):\n",
    "    key_value_pairs = [comment.split(':') for comment in comments]\n",
    "    return {pair[0]: pair[1] for pair in key_value_pairs}\n",
    "\n",
    "def record_to_row(record, patient_id):\n",
    "    row = {}\n",
    "    row['patient'] = patient_id\n",
    "    row['name'] = record.record_name\n",
    "    row['label'] = comments_to_dict(record.comments)['Reason for admission'][1:]\n",
    "    row['signals'] = record.p_signal\n",
    "    row['signal_length'] = record.sig_len\n",
    "    channels = record.sig_name\n",
    "    signals = record.p_signal.transpose()\n",
    "    \n",
    "    row['channels'] = channels\n",
    "    \n",
    "    for channel, signal in zip(channels, signals):\n",
    "        row[channel] = signal\n",
    "        \n",
    "    return row\n",
    "\n",
    "def make_set(df_data, channels, label_map, record_id, window_size=2048):\n",
    "    n_windows = 0\n",
    "    \n",
    "    for _, record in tqdm(df_data.iterrows()):\n",
    "        n_windows+= record['signal_length']//window_size\n",
    "\n",
    "    dataX = np.zeros((n_windows, len(channels), window_size))\n",
    "    dataY = np.zeros((n_windows, len(label_map)))\n",
    "    \n",
    "    record_list = []\n",
    "    \n",
    "    nth_window = 0\n",
    "    for i, (patient, record) in enumerate(tqdm(df_data.iterrows())):\n",
    "        # read the record, get the signal data and transpose it\n",
    "        signal_data = io.rdrecord(os.path.join('ptb-diagnostic-ecg-database-1.0.0', record['name'])).p_signal.transpose()\n",
    "        n_rows = signal_data.shape[-1]\n",
    "        n_windows = n_rows//window_size\n",
    "        dataX[nth_window:nth_window+n_windows] = np.array([signal_data[:,i*window_size:(i+1)*window_size] for i in range(n_windows)])\n",
    "        dataY[nth_window:nth_window+n_windows][:, label_map[record.label]] = 1\n",
    "        nth_window+=n_windows\n",
    "        \n",
    "        if record_id:\n",
    "            record_list+= n_windows*[record['name']]\n",
    "        \n",
    "    return dataX, dataY, record_list\n",
    "\n",
    "record_names = io.get_record_list('ptbdb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a41d6d7658a6473eb7efd75533dac622",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=549), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "records = []\n",
    "for record_name in tqdm(record_names):\n",
    "    record = io.rdrecord(record_name=os.path.join('ptb-diagnostic-ecg-database-1.0.0', record_name))\n",
    "    label = comments_to_dict(record.comments)['Reason for admission'][1:]\n",
    "    patient = record_name.split('/')[0]\n",
    "    signal_length = record.sig_len\n",
    "    records.append({'name':record_name, 'label':label, 'patient':patient, 'signal_length':signal_length})\n",
    "    \n",
    "channels = record.sig_name\n",
    "df_records = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Myocardial infarction     368\n",
       "Healthy control            80\n",
       "n/a                        27\n",
       "Cardiomyopathy             17\n",
       "Bundle branch block        17\n",
       "Dysrhythmia                16\n",
       "Hypertrophy                 7\n",
       "Valvular heart disease      6\n",
       "Myocarditis                 4\n",
       "Stable angina               2\n",
       "Unstable angina             1\n",
       "Heart failure (NYHA 2)      1\n",
       "Palpitation                 1\n",
       "Heart failure (NYHA 3)      1\n",
       "Heart failure (NYHA 4)      1\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = df_records['label'].unique()\n",
    "df_records['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_labels = [\n",
    "    'Healthy control',\n",
    "    'Myocardial infarction'\n",
    "    ]\n",
    "df_selected = df_records.loc[df_records['label'].isin(selected_labels)]\n",
    "label_map = {label: value for label, value in zip(selected_labels, range(len(selected_labels)))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_patients = []\n",
    "train_patients = []\n",
    "test_size = 0.2\n",
    "channels\n",
    "for label in selected_labels:\n",
    "    df_selected = df_records.loc[df_records['label'] == label]\n",
    "    patients = df_selected['patient'].unique()\n",
    "    n_test = math.ceil(len(patients)*test_size)\n",
    "    test_patients+=list(np.random.choice(patients, n_test, replace=False))\n",
    "    train_patients+=list(patients[np.isin(patients, test_patients, invert=True)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c18e4c7cc723445a903cc17e08203ac2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "238888ceb51945658a264d179ed11e05",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "df_patient_records = df_records.set_index('patient')\n",
    "df_train_patients = df_patient_records.loc[train_patients]\n",
    "df_test_patients = df_patient_records.loc[test_patients]\n",
    "window_size = 2048#df_records['signal_length'].min()\n",
    "#trainX, trainY, _ = make_set(df_train_patients, channels, label_map, False, window_size)\n",
    "testX, testY, record_list = make_set(df_test_patients, channels, label_map, True, window_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_model(input_shape, output_dim, lstm_layer, dropout=0.2):\n",
    "    print(\"model dim: \", input_shape, output_dim)\n",
    "    model = Sequential()\n",
    "    model.add(lstm_layer(256, return_sequences=True, input_shape=input_shape, batch_size=None))\n",
    "    model.add(Dropout(dropout))\n",
    "    model.add(lstm_layer(128, return_sequences=True))\n",
    "    model.add(Dropout(dropout))\n",
    "    model.add(LSTM(64))\n",
    "    model.add(Dropout(dropout))\n",
    "    model.add(Dense(output_dim, activation='softmax'))\n",
    "    model.compile(loss='categorical_crossentropy', optimizer='adam')\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TimeHistory(keras.callbacks.Callback):\n",
    "    def on_train_begin(self, logs={}):\n",
    "        self.times = []\n",
    "\n",
    "    def on_epoch_begin(self, batch, logs={}):\n",
    "        self.epoch_time_start = time.time()\n",
    "\n",
    "    def on_epoch_end(self, batch, logs={}):\n",
    "        self.times.append(time.time() - self.epoch_time_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1337)\n",
    "test_patients = []\n",
    "train_patients = []\n",
    "test_size = 0.2\n",
    "channels\n",
    "for label in selected_labels:\n",
    "    df_selected = df_records.loc[df_records['label'] == label]\n",
    "    patients = df_selected['patient'].unique()\n",
    "    n_test = math.ceil(len(patients)*test_size)\n",
    "    test_patients+=list(np.random.choice(patients, n_test, replace=False))\n",
    "    train_patients+=list(patients[np.isin(patients, test_patients, invert=True)])\n",
    "    \n",
    "df_patient_records = df_records.set_index('patient')\n",
    "df_train_patients = df_patient_records.loc[train_patients]\n",
    "df_test_patients = df_patient_records.loc[test_patients]\n",
    "window_size = 2048#df_records['signal_length'].min()\n",
    "trainX, trainY, _ = make_set(df_train_patients, channels, label_map, False, window_size)\n",
    "testX, testY, record_list = make_set(df_test_patients, channels, label_map, True, window_size)\n",
    "\n",
    "#Shuffle order of train set\n",
    "trainX, trainY = shuffle(trainX, trainY)\n",
    "\n",
    "#Since we have a large class inbalance we need to udjust the weights for it.\n",
    "fractions = 1-trainY.sum(axis=0)/len(trainY)\n",
    "weights = fractions[trainY.argmax(axis=1)]\n",
    "\n",
    "#df_selected['patient'].sample(len())\n",
    "\n",
    "filepath = os.path.join('models', \"weights-improvement-{epoch:02d}-bigger.hdf5\")\n",
    "checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')\n",
    "\n",
    "model_name = 'two_classes'\n",
    "model_folder = os.path.join('tensorlogs', model_name + \"-logs/\")\n",
    "\n",
    "if not os.path.isdir(model_folder):\n",
    "    n_logs = 0\n",
    "else:\n",
    "    n_logs = len(os.listdir(model_folder))\n",
    "    \n",
    "tensorboard_logs = os.path.join(model_folder, \"%inth_run\"%n_logs)\n",
    "tensorboard_callback = keras.callbacks.TensorBoard(log_dir=tensorboard_logs, write_graph=False)\n",
    "time_callback = TimeHistory()\n",
    "callbacks = [checkpoint, tensorboard_callback, time_callback]\n",
    "\n",
    "model = make_model((trainX.shape[1], trainX.shape[2]), trainY.shape[-1], CuDNNLSTM)\n",
    "\n",
    "model.fit(trainX, trainY, epochs=50, batch_size=512, sample_weight=weights, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
